/**************************************************************************\
*//*! \file HLS_Printer.cpp
** \author  Neil Turton <neilt@amd.com>
**  \brief  Nanotube HLS printer
**   \date  2020-08-18
*//*
\**************************************************************************/

/**************************************************************************
** Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
** SPDX-License-Identifier: MIT
**************************************************************************/

#include "HLS_Printer.h"

#include "hls_validate.hpp"
#include "Intrinsics.h"
#include "llvm_common.h"
#include "llvm_insns.h"
#include "llvm_pass.h"
#include "setup_func.hpp"
#include "utils.h"

#include "llvm/IR/CFG.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"

#include <algorithm>
#include <cassert>
#include <cctype>
#include <cerrno>
#include <cstring>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <limits>
#include <list>
#include <sstream>
#include <utility>

#include <sys/stat.h>
#include <sys/types.h>

#define DEBUG_TYPE "nanotube-hls"

using namespace nanotube;

// The HLS printer pass
// ====================
//
// The HLS printer writes the separate processing stages into
// different C++ files which can then be compiled using Vivado HLS.
// Each Nanotube thread in the input is converted into a separate HLS
// module.  In HLS simulation, these are invoked by a Nanotube thread
// which runs a polling function which transfers data between the
// Nanotube channels and the HLS streams which connect to the HLS
// modules.
//
// Input conditions
// ----------------
//
// The program must consist of a collection of Nanotube threads which
// communicate using Nanotube channels.  The setup function must
// satisfy the requirements of the setup_func class.
//
// Each thread function must consist of a single infinite loop.  It
// must obey the constraints described in the hls_validate class.
// Also, each pointer value used must be one of the following:
//   A constant offset into a global variable
//   A constant offset into a stack buffer
//   A pointer to a Nanotube context.
//
// There must be no phi nodes in the loop entry basic block.
//
// Each global variable must be accessed by only a single thread.
//
// Output conditions
// -----------------
//
// The LLVM-IR is not modified.  Files are generated to build the
// program using Vivado HLS.
//
// Theory of operation
// -------------------
//
// The pass is split into the stage_writer, top_writer and hls_printer
// classes.  The stage_writer is responsible for writing the HLS
// implementation of the processing stage associated with a thread.
// The top_writer creates a stage_writer object for each thread and
// invokes it to write its code.  It also writes the top level code.
// The hls_printer class creates the output directory, creates the
// top_writer and invokes it to write the output files.
//
// The top_writer uses the setup_func class to examine the Nanotube
// setup function.  This identifies channels, threads, contexts and
// ports.  It generates a header file which declares the HLS functions
// for the stages and a JSON file with a structured representation of
// the channels and stages.  It also generates a top-level C++ file
// containing a Nanotube setup function and a polling function which
// invokes the HLS functions.  Each HLS function returns a bool value
// indicating whether the function performed any work.
//
// The generated setup function creates Nanotube channels for each of
// the channels in the input setup function.  It creates a single
// context, adds all the channels to it and creates a thread which
// uses the context and calls the polling function.
//
// The polling function consists of an infinite loop which processes
// the HLS functions in turn.  For each one, it has HLS streams for
// each of the Nanotube channels used by the corresponding thread
// function.  It tries to fill the input HLS streams from the
// corresponding Nanotube channel.  If there is space in the output
// HLS streams, it then invokes the HLS function.  It then tries to
// empty the output HLS streams into the corresponding Nanotube
// channels.  At the end of the loop, it determines whether any work
// has been done and calls nanotube_thread_wait if not.
//
// The stage_writer creates a file consisting of three parts.  The
// first part is the preamble which contains comments and #include
// directives.  It is mostly boiler-plate.
//
// The second part is the state which contains definitions of global
// variables which are used by the stage.  The variables are
// identified by scanning the thread function for load and store
// operations which reference a global variable.
//
// The final part is the function which contains the definition of the
// HLS function describing the behaviour of the stage.  The function
// consists of variable declarations, pragmas and a body.
//
// The variable declarations are generated by scanning the thread
// function for non-pointer values and allocas.  These are the values
// which need to be computed by the function body.  Pointer values are
// required to be statically computable.
//
// The pragmas are generated as some boiler-plate for the function as
// a whole and more boiler-plate for each port.
//
// The body of the function is generated by one-to-one conversion of
// the LLVM-IR into C++.  Each instruction is identified and the
// corresponding C++ code is generated.

///////////////////////////////////////////////////////////////////////////

namespace {

class top_writer;
class stage_writer;
class hls_printer;

// Numeric identifier used to name values in the generated code.  The
// all-ones value is reserved to mean "no identifier".
using value_id_t = uint32_t;
const value_id_t value_id_none = static_cast<value_id_t>(-1);

// Numeric identifier used to name labels in the generated code.  The
// all-ones value is reserved to mean "no identifier".
using label_id_t = uint32_t;
const label_id_t label_id_none = static_cast<label_id_t>(-1);

// The kind of value an operand is expected to have when it is written
// (see stage_writer::write_operand, which defaults to HLS_TYPE_INTEGER).
enum hls_type
{
  HLS_TYPE_ANY,      // No particular expectation.
  HLS_TYPE_INTEGER,
  HLS_TYPE_POINTER,
  HLS_TYPE_STRUCT,
};

///////////////////////////////////////////////////////////////////////////

// A class which holds information about a static variable.  This is
// used to track how many threads access the variable and whether any
// of them write to it.
//
// Each distinct thread occupies one slot.  Only NUM_ACCESSES slots are
// kept, which is enough to describe a conflict: one writing thread and
// one other thread.
class static_var_info
{
public:
  // A struct which holds information about an access.
  struct access {
    // The thread which performed the access.
    thread_id_t thread_id{};
    // The instruction which performed the access.  This may be null
    // when the caller does not supply an instruction.
    const Instruction *insn = nullptr;
  };

  // Construct an instance with no recorded accesses.
  static_var_info();

  // Add an access to the ones being tracked.
  void add_access(thread_id_t thread_id, const Instruction *insn,
                  bool is_write);

  // Check whether a conflict has occurred, i.e. at least two distinct
  // threads access the variable and at least one of them writes it.
  bool check_conflict() const;

  // Get a pointer to a write access.  Only call if check_conflict
  // returned true.
  const access *get_write_access() const;

  // Get a pointer to an access which conflicts with the write access.
  // Only call if check_conflict returned true.
  const access *get_other_access() const;

private:
  // The capacity of the m_accesses array.
  static const int NUM_ACCESSES = 2;

  // An array of accesses, one per distinct thread.
  access m_accesses[NUM_ACCESSES];

  // The number of entries of m_accesses which are in use.
  int m_num_accesses;

  // The index into the array of the access which performs a write, or
  // -1 if no write has been recorded.
  int m_write_index;
};


///////////////////////////////////////////////////////////////////////////

// Writes the top-level output files and drives one stage_writer per
// thread.  It also tracks which threads access each static variable so
// that conflicting accesses can be reported.
class top_writer
{
public:
  top_writer(hls_printer &printer, Module &m, setup_func &s);

  // The pass which owns this writer.
  hls_printer &get_printer() { return m_printer; }

  // The information about the setup function.
  setup_func &get_setup_func() { return m_setup_func; }

  // Record that thread_id accesses the static variable var.  A fatal
  // error is reported if one thread writes the variable and a
  // different thread also accesses it.
  void set_thread_of_var(const Value &var, thread_id_t thread_id,
                         bool is_write);

  // Validate the sideband layout of every channel.
  void check_channel_data_widths();

  // Output file generators.
  void write_stages();
  void write_header();
  void write_json();
  void write_vitis_opts();
  void write_poll_thread();

  // Output helpers.  write_header appends ";" after each call to
  // output_prototype.  NOTE(review): output_c_string's definition is
  // not shown here; presumably it writes str as a C string literal.
  void output_prototype(raw_os_ostream &out,
                        thread_id_t thread_id);
  void output_c_string(raw_os_ostream &out, StringRef str);

private:
  hls_printer &m_printer;            // The module pass.
  setup_func &m_setup_func;          // The setup function.
  const DataLayout &m_data_layout;   // The DataLayout for the module.

  // Maps each static variable to information about the threads which
  // access it.  Used to report errors.
  DenseMap<const Value *, static_var_info> m_static_var_infos;
};

///////////////////////////////////////////////////////////////////////////

// Emits global variable definitions to an output stream.
// NOTE(review): the definitions of the constructor and write() are not
// visible in this part of the file.
class global_var_writer
{
public:
  global_var_writer(const DataLayout *data_layout, setup_func *setup_fn,
                    raw_os_ostream *stream);

  // Emit the definition of var; presumably id is used to name it.
  void write(value_id_t id, GlobalVariable *var);

private:
  const DataLayout *m_data_layout;  // The module's data layout.
  setup_func *m_setup_func;         // Information about the setup function.
  raw_os_ostream *m_out;            // The destination stream.
};

///////////////////////////////////////////////////////////////////////////

// Writes the HLS implementation of the processing stage associated
// with one Nanotube thread.  The output consists of a preamble, the
// state (definitions of global variables used by the stage) and the
// HLS function itself.
class stage_writer
{
public:
  stage_writer(top_writer &top,
               context_info &context,
               thread_id_t thread_id,
               std::ostream &out);

  // ---- Top-level sections of the output file ----
  void write();
  void write_preamble();
  void write_state();
  void write_state_for_operand(Value *pointer);
  void write_function();

  // ---- Parts of the HLS function ----
  void write_declarations();
  void write_decl_type(Type *type, unsigned depth = 0);
  void write_decl_for_alloca(const AllocaInst *alloca);
  void write_decl_for_insn(const Instruction *insn);
  void write_pragmas();
  void write_body();

  // ---- Translation of individual instructions ----
  void write_binary_op(const std::string &op,
                       const BinaryOperator &insn,
                       bool is_signed = false);
  void write_branch(const BranchInst &insn);
  void write_cast(const CastInst &insn, bool is_signed);
  void write_extract_value(const ExtractValueInst &insn);
  void write_gep(const GetElementPtrInst &insn);
  void write_icmp(const ICmpInst &insn);
  void write_icmp_body(const std::string &op_str, bool is_signed,
                       const ICmpInst &insn);
  void write_load(const LoadInst &insn);
  void write_select(const SelectInst &insn);
  void write_store(const StoreInst &insn);
  void write_switch(const SwitchInst &insn);

  // ---- Translation of calls ----
  void write_call(const CallBase &insn);
  void write_bswap(const CallBase &insn);
  void write_channel_call(const CallBase &insn, bool is_read);
  void write_debug_trace_call(const CallBase &insn);
  void write_trace_buffer_call(const CallBase &insn);
  void write_memcpy(const CallBase &insn);
  void write_memset(const CallBase &insn);
  void write_memcmp(const CallBase &insn);
  void write_usub_sat(const CallBase &insn);
  void write_usub_with_overflow(const CallBase &insn);

  // ---- Operand, control-flow and memory-access helpers ----
  void write_cfg_edge(StringRef indent, const BasicBlock &from_bb,
                      const BasicBlock &to_bb);
  void write_operand(const Instruction &insn, const Value &val,
                     enum hls_type expected_type = HLS_TYPE_INTEGER);
  void write_operand_base(const Instruction &insn, const Value &val);
  void check_mem_access(const Instruction *insn, const Value *base,
                        uint64_t size, bool is_write);
  void check_mem_access(const Instruction *insn, const Value *base,
                        const Value *size, bool is_write);
  void check_mem_access(const Instruction *insn, const Value *base,
                        APInt size, bool is_write);

private:
  top_writer &m_top;                 // Keeps track of global state.
  setup_func &m_setup_func;          // Information about the setup function.
  thread_id_t m_thread_id;           // The thread being written.
  raw_os_ostream m_out;              // The stream to write.
  const thread_create_args &m_args;  // Arguments to nanotube_thread_create.
  context_info &m_context;           // Context passed to nanotube_thread_create.
  const DataLayout &m_data_layout;   // The data layout of the module.
  const BasicBlock *m_entry_bb;      // The entry basic block.

  // Value -> ID used to name instructions which produce values.
  DenseMap<const Value *, value_id_t> m_local_var_ids;

  // Value -> ID used to name static variables.
  DenseMap<const Value *, value_id_t> m_static_var_ids;

  // Basic block -> ID used to name labels.
  DenseMap<const BasicBlock *, label_id_t> m_label_ids;

  // Whether the basic block currently being written is the entry block.
  bool m_is_entry_bb{true};

  // The next value ID to hand out.
  value_id_t m_next_value_id{0};
};

///////////////////////////////////////////////////////////////////////////

// The module pass.  It writes HLS sources for the module into an
// output directory and does not modify the IR.
class hls_printer: public llvm::ModulePass
{
public:
  static char ID;

  hls_printer(const std::string &output_directory, bool overwrite);
  StringRef getPassName() const override { return "Nanotube HLS Printer"; }
  void getAnalysisUsage(AnalysisUsage &Info) const override;

  // The directory into which the output files are written.
  StringRef get_output_dir() const { return m_output_directory; }

  bool runOnModule(Module &M) override;

private:
  // The directory into which the output files are written.
  std::string m_output_directory;

  // Whether an already existing output directory may be reused.
  bool m_overwrite;
};

///////////////////////////////////////////////////////////////////////////

} // namespace

///////////////////////////////////////////////////////////////////////////

char hls_printer::ID = 0;

hls_printer::hls_printer(const std::string &output_directory,
                         bool overwrite):
  ModulePass(ID),
  m_output_directory(output_directory),
  m_overwrite(overwrite)
{
}

void hls_printer::getAnalysisUsage(AnalysisUsage &Info) const
{
  Info.setPreservesAll();
//  Info.addRequired<Pointer_Analysis_Pass>();
}

// Generate the HLS output for module m.
//
// The setup function is analysed first, then the output directory is
// created.  If the directory already exists, this is only accepted
// when overwriting is enabled, and the existing path must really be a
// directory.  Otherwise every output file would silently fail to open.
// The IR is never modified, so this always returns false.
bool hls_printer::runOnModule(Module &m)
{
  // Resolve the setup function.  Do this before creating the
  // directory in case an error is found.
  setup_func setup(m);

  // Create the directory.
  ::mode_t mode = ( S_IRWXU | S_IRWXG | S_IRWXO );
  int rc = ::mkdir(m_output_directory.c_str(), mode);
  if (rc != 0) {
    // Capture errno before any other call can clobber it.
    int err = errno;
    struct stat st;
    if (err != EEXIST || !m_overwrite) {
      report_fatal_errorv("Failed to create directory '{0}': {1} (Error {2}).",
                          m_output_directory, std::strerror(err), err);
    } else if (::stat(m_output_directory.c_str(), &st) != 0 ||
               !S_ISDIR(st.st_mode)) {
      // EEXIST is also returned when the path names a regular file.
      report_fatal_errorv("Output path '{0}' exists but is not a directory.",
                          m_output_directory);
    }
  }

  // Process the setup function.
  top_writer top(*this, m, setup);

  // Check the widths of the different fields is sufficient and consistent
  top.check_channel_data_widths();

  // Write the stages.
  top.write_stages();

  // Write the global files.
  top.write_header();
  top.write_json();
  top.write_vitis_opts();
  top.write_poll_thread();

  // Nothing was modified.
  return false;
}

///////////////////////////////////////////////////////////////////////////

static_var_info::static_var_info():
  m_num_accesses(0),
  m_write_index(-1)
{
}

// Record an access to the variable by thread_id through insn.
//
// Each distinct thread gets one slot.  When all slots are in use, the
// last slot is reused for the new thread, unless a conflict has
// already been recorded.  In that case the existing entries are kept so
// that the conflict can still be reported.  When an access is a write,
// its instruction is stored in the slot, so the write access refers to
// the writing instruction even if the thread was first seen reading.
void static_var_info::add_access(thread_id_t thread_id,
                                 const Instruction *insn,
                                 bool is_write)
{
  // Look through the array for a matching access.
  int index = 0;
  while (index < m_num_accesses && m_accesses[index].thread_id != thread_id)
    ++index;

  // Consider adding an entry if there was no match.
  if (index >= m_num_accesses) {
    if (index >= NUM_ACCESSES) {
      // The array is full.  A full array with a write is a conflict,
      // so leave it untouched in that case.
      if (m_write_index >= 0)
        return;

      // Otherwise overwrite the last entry.
      index = NUM_ACCESSES - 1;
    } else {
      // There is space in the array.  Allocate a new slot.
      m_num_accesses = index + 1;
    }

    m_accesses[index].thread_id = thread_id;
    m_accesses[index].insn = insn;
  }

  // Mark the index as a write if it was written, remembering the
  // instruction which performed the write.
  if (is_write) {
    m_write_index = index;
    m_accesses[index].insn = insn;
  }
}

// A conflict needs a recorded write plus at least one further thread.
bool static_var_info::check_conflict() const
{
  if (m_write_index < 0)
    return false;
  return m_num_accesses >= 2;
}

// Return the entry of the thread which wrote the variable.  A write
// must have been recorded in a slot that is in use.
const static_var_info::access *static_var_info::get_write_access() const
{
  assert(m_write_index >= 0);
  assert(m_write_index < NUM_ACCESSES);
  assert(m_write_index < m_num_accesses);

  const access *writer = &m_accesses[m_write_index];
  return writer;
}

// Return the entry which is not the write access.  With two slots,
// this is simply the slot the writer does not occupy.
const static_var_info::access *static_var_info::get_other_access() const
{
  assert(m_write_index >= 0);

  int other_index = 1;
  if (m_write_index != 0)
    other_index = 0;

  assert(other_index < NUM_ACCESSES);
  assert(other_index < m_num_accesses);

  const access *other = &m_accesses[other_index];
  return other;
}

///////////////////////////////////////////////////////////////////////////

top_writer::top_writer(hls_printer &printer, Module &m, setup_func &s):
  m_printer(printer),
  m_setup_func(s),
  m_data_layout(m.getDataLayout())
{
}

void top_writer::set_thread_of_var(const Value &var,
                                   thread_id_t thread_id,
                                   bool is_write)
{
  auto ins = m_static_var_infos.insert(std::make_pair(&var, static_var_info()));
  auto it = ins.first;
  assert(it != m_static_var_infos.end());
  auto *info = &(it->second);

  info->add_access(thread_id, nullptr, is_write);
  if (info->check_conflict()) {
    auto *write_access = info->get_write_access();
    auto *other_access = info->get_other_access();

    thread_id_t write_thread_id = write_access->thread_id;
    thread_id_t other_thread_id = other_access->thread_id;
    thread_info &write_thread = m_setup_func.get_thread_info(write_thread_id);
    thread_info &other_thread = m_setup_func.get_thread_info(other_thread_id);
    report_fatal_errorv("State variable {0} used by thread '{1}' and"
                        " thread '{2}'.", var,
                        write_thread.args().name,
                        other_thread.args().name);
  }
}

void top_writer::check_channel_data_widths()
{
  channel_index_t num_channels = (channel_index_t)m_setup_func.channels().size();
  for (channel_index_t channel_index=0; channel_index < num_channels; channel_index++) {
    auto &channel = m_setup_func.get_channel_info(channel_index);
    auto elem_size = channel.get_elem_size();
    auto sideband_size = channel.get_sideband_size();
    auto sideband_signals_size = channel.get_sideband_signals_size();
    auto data_width = (elem_size - sideband_size - sideband_signals_size);
    auto exp_sideband_signals_size = ((data_width + 7) / 8) * 2 + 1;

    if (sideband_size >= elem_size) {
      report_fatal_errorv("Sideband size {0} is larger than the element size {1}",
                          sideband_size, elem_size);
    }
    if (sideband_signals_size != 0 &&
        sideband_signals_size != exp_sideband_signals_size) {
      report_fatal_errorv("Sideband signals size {0} for data width {1} should "
                          "be zero or {2}", sideband_signals_size, data_width,
                          exp_sideband_signals_size);
    }
  }
}


void top_writer::write_stages()
{
  thread_id_t num_threads = m_setup_func.threads().size();
  for (thread_id_t id = 0; id < num_threads; id++) {
    std::string filename = formatv("{0}/stage_{1}.cc",
                                   m_printer.get_output_dir(), id);
    std::ofstream out_fstream(filename);

    auto context_index = m_setup_func.get_thread_info(id).context_index();
    context_info &context = m_setup_func.get_context_info(context_index);
    stage_writer writer(*this, context, id, out_fstream);
    writer.write();
  }
}

void top_writer::write_vitis_opts()
{
  // This is currently only needed for connectivity of plugin builds
  // so would ideally only produce this output if plugin sockets are configured
  // but these are set per-channel rather than globally
  // so just generate sensible output even if they're not

  std::string filename = formatv("{0}/vitis_opts.ini",
                                 m_printer.get_output_dir());
  std::ofstream out_fstream(filename);
  raw_os_ostream out(out_fstream);

  out << ("[connectivity]\n");
  thread_id_t thread_id = 0;
  for (auto &thread: m_setup_func.threads()) {
    (void)thread;
    out << "nk=stage_" << thread_id << ":1:stage_" << thread_id << "\n";
    ++thread_id;
  }


  for (auto &channel: m_setup_func.channels()) {
    const std::string &input_interface = channel.get_input_interface();
    const std::string &output_interface = channel.get_output_interface();

    if (channel.has_reader() && channel.has_writer()) {
      context_info reader_context = m_setup_func.get_context_info(channel.get_reader_context());
      context_info writer_context = m_setup_func.get_context_info(channel.get_writer_context());
      auto num_elem = channel.get_num_elem();

      out << "sc=stage_" << writer_context.get_thread_id() <<
             ".port" << channel.get_writer_port() <<
             ":stage_" << reader_context.get_thread_id() <<
             ".port" << channel.get_reader_port();
      // Only ask for a FIFO if we need to support more than one element
      if (num_elem > 1) {
        out << ":" << num_elem;
      }
      out << "\n";
    }
    else if (channel.has_reader() && !channel.has_writer() && !input_interface.empty()) {
      context_info reader_context = m_setup_func.get_context_info(channel.get_reader_context());
      out << "sc=" << input_interface <<
             ":stage_" << reader_context.get_thread_id() <<
             ".port" << channel.get_reader_port() << "\n";
    }
    else if (channel.has_writer() && !channel.has_reader() && !output_interface.empty()) {
      context_info writer_context = m_setup_func.get_context_info(channel.get_writer_context());
      out << "sc=stage_" << writer_context.get_thread_id() <<
             ".port" << channel.get_writer_port() <<
             ":" << output_interface << "\n";
    }
  }
}

void top_writer::write_json()
{
  std::string filename = formatv("{0}/pipeline.json",
                                 m_printer.get_output_dir());
  std::ofstream out_fstream(filename);
  raw_os_ostream out(out_fstream);

  out << ( "{\n"
           "  \"channels\": [\n" );
  channel_index_t channel_index = 0;
  for (auto &channel: m_setup_func.channels()) {
    if (channel_index != 0)
      out << ",\n";
    out <<
      "    {\n"
      "      \"channel_id\": " << channel_index << ",\n"
      "      \"elem_size\": " << channel.get_elem_size() << ",\n"
      "      \"num_elem\": " << channel.get_num_elem() << "\n"
      "    }";
    channel_index++;
  }
  if (channel_index != 0)
    out << "\n";
  out << ( "  ],\n"
           "  \"stages\": [\n" );
  thread_id_t thread_id = 0;
  for (auto &thread: m_setup_func.threads()) {
    context_info &context = m_setup_func.get_context_info(thread.context_index());
    if (thread_id != 0)
      out << ",\n";
    out <<
      "    {\n"
      "      \"thread_id\": " << thread_id << ",\n"
      "      \"ports\": [";
    port_index_t num_ports = context.ports().size();
    for (port_index_t port_index = 0; port_index < num_ports; port_index++) {
      stage_port &port = context.get_port(port_index);
      out << (port_index == 0 ? " " : ", ") << port.channel_index();
    }
    out << ( " ]\n"
             "    }" );

    thread_id++;
  }
  if (thread_id != 0)
    out << "\n";
  out << ( "  ]\n"
           "}\n" );
}

void top_writer::write_header()
{
  std::string filename = formatv("{0}/stages.hh",
                                 m_printer.get_output_dir());
  std::ofstream out_fstream(filename);
  raw_os_ostream out(out_fstream);

  out << ( "#ifndef STAGES_HH\n"
           "#define STAGES_HH\n"
           "\n"
           "#include \"ap_axi_sdata.h\"\n"
           "#include \"hls_stream.h\"\n"
           "#include <byteswap.h>\n"
           "#include <cstddef>\n"
           "#include <cstdint>\n"
           "\n"
           // Include an implementation of memcpy to avoid the HLS
           // error "memory copy is not supported unless used on bus
           // interface possible cause(s): non-static/non-constant
           // local array with initialization"
           "static inline void nanotube_memcpy(void *dest, const void *src, size_t n)\n"
           "{\n"
           "#pragma HLS inline\n"
           "  for(size_t i=0; i<n; i++)\n"
           "    ((char*)dest)[i] = ((const char*)src)[i];\n"
           "}\n"
           "\n"
           // Include an implementation of memcpy to avoid the HLS
           // error "Undefined function memcmp".
           "static inline int\n"
           "nanotube_memcmp(const void *src1, const void *src2, size_t n)\n"
           "{\n"
           "#pragma HLS inline\n"
           "  const char* p1 = (const char*)src1;\n"
           "  const char* p2 = (const char*)src2;\n"
           "  for (size_t i=0; i<n; i++) {\n"
           "    if (p1[i] != p2[i])\n"
           "      return (p1[i] < p2[i] ? -1 : 1);\n"
           "  }\n"
           "  return 0;\n"
           "}\n"
           "\n"
           "template<int N> struct bytes {\n"
           "  uint8_t data[N];\n"
           "};\n\n"
    );

  thread_id_t num_threads = m_setup_func.threads().size();
  for (thread_id_t thread_id=0; thread_id<num_threads; thread_id++) {
    output_prototype(out, thread_id);
    out << ";\n\n";
  }

  out << ( "#endif // STAGES_HH\n" );
}

void top_writer::write_poll_thread()
{
  std::string filename = formatv("{0}/poll_thread.cc",
                                 m_printer.get_output_dir());
  std::ofstream out_fstream(filename);
  raw_os_ostream out(out_fstream);

  // Write the file header.
  out << (
    "#include \"stages.hh\"\n"
    "#include \"nanotube_api.h\"\n"
    "\n"
    "static void poll_thread(nanotube_context_t* context, void *arg)\n"
    "{\n"
  );

  // Declare HLS streams and buffers for all the ports.
  thread_id_t num_threads = m_setup_func.threads().size();
  for (thread_id_t thread_id=0; thread_id<num_threads; thread_id++) {
    thread_info &thread = m_setup_func.get_thread_info(thread_id);
    context_index_t context_index = thread.context_index();
    context_info &context = m_setup_func.get_context_info(context_index);
    port_index_t num_ports = context.ports().size();
    for (port_index_t port_index=0; port_index < num_ports; port_index++) {
      auto &port = context.get_port(port_index);
      auto &channel = m_setup_func.get_channel_info(port.channel_index());
      auto elem_size = channel.get_elem_size();
      auto sideband_size = channel.get_sideband_size();
      auto sideband_signals_size = channel.get_sideband_signals_size();

      if (sideband_size != 0 || sideband_signals_size != 0 ) {
        out << formatv("  hls::stream<ap_axiu<{0},{1},0,0> > "
                       "stage{2}_port{3}_stream"
                       "(\"stage{2}_port{3}\");\n",
                       (elem_size - sideband_size - sideband_signals_size) << 3, sideband_size << 3, 
                       thread_id, port_index);
        out << formatv("  ap_axiu<{0},{1},0,0>               "
                       "stage{2}_port{3}_buffer;\n",
                       (elem_size - sideband_size - sideband_signals_size) << 3, sideband_size << 3,
                       thread_id, port_index);
        out << formatv("  bytes<{0}>                       stage{1}_port{2}_flat;\n",
                       elem_size, thread_id, port_index);
      }
      else {
        out << formatv("  hls::stream<bytes<{0}> > "
                       "stage{1}_port{2}_stream"
                       "(\"stage{1}_port{2}\");\n",
                       elem_size, thread_id, port_index);
        out << formatv("  bytes<{0}>               "
                       "stage{1}_port{2}_buffer;\n",
                       elem_size, thread_id, port_index);
      }
    }
  }

  // Start the loop.
  out << (
    "\n"
    "  while (true) {\n"
    "    bool active = false;\n"
  );

  // Poll the stages.
  for (thread_id_t thread_id=0; thread_id<num_threads; thread_id++) {
    thread_info &thread = m_setup_func.get_thread_info(thread_id);
    context_index_t context_index = thread.context_index();
    context_info &context = m_setup_func.get_context_info(context_index);
    port_index_t num_ports = context.ports().size();

    out << "\n";

    // Poll the input ports.
    for (port_index_t port_index=0; port_index < num_ports; port_index++) {
      auto &port = context.get_port(port_index);
      if (!port.is_read())
        continue;
      auto &channel = m_setup_func.get_channel_info(port.channel_index());
      auto elem_size = channel.get_elem_size();
      auto sideband_size = channel.get_sideband_size();
      auto sideband_signals_size = channel.get_sideband_signals_size();
      auto flatten = (sideband_size != 0 || sideband_signals_size != 0);

      out << formatv("    if (stage{0}_port{1}_stream.empty())",
                     thread_id, port_index) << " {\n"
          << formatv("      if (nanotube_channel_try_read(context, {0}, "
                     "&stage{1}_port{2}_{3}, {4}))",
                     port.channel_index(), thread_id, port_index, 
                     flatten ? "flat" : "buffer", elem_size) << " {\n"
          << formatv("        active = true;\n");
      if (flatten) {
        out << formatv("        nanotube_memcpy(&stage{0}_port{1}_buffer.data, "
                       "&(stage{0}_port{1}_flat.data[0]), {2});\n",
                       thread_id, port_index, elem_size - sideband_size - sideband_signals_size);
        if (sideband_size) {
          out << formatv("        nanotube_memcpy(&stage{0}_port{1}_buffer.user, "
                       "&(stage{0}_port{1}_flat.data[{2}]), {3});\n",
                       thread_id, port_index, elem_size - sideband_size - sideband_signals_size,
                       sideband_size);
        }
        if (sideband_signals_size) {
          out << formatv("        nanotube_memcpy(&stage{0}_port{1}_buffer.keep, "
                       "&(stage{0}_port{1}_flat.data[{2}]), {3});\n",
                       thread_id, port_index, elem_size - sideband_signals_size, 
                       (sideband_signals_size-1)/2);
          out << formatv("        nanotube_memcpy(&stage{0}_port{1}_buffer.strb, "
                       "&(stage{0}_port{1}_flat.data[{2}]), {3});\n",
                       thread_id, port_index, elem_size - (sideband_signals_size-1)/2 - 1,
                       (sideband_signals_size-1)/2);
          out << formatv("        stage{0}_port{1}_buffer.last = "
                         "stage{0}_port{1}_flat.data[{2}] & 1;\n",
                         thread_id, port_index, elem_size - 1);
        }
      }
      out << formatv("        stage{0}_port{1}_stream.write("
                     "stage{0}_port{1}_buffer);\n",
                     thread_id, port_index)
          << "      }\n"
          << "    }\n";
    }

    // Invoke the HLS function.
    bool any_conds = false;
    for (port_index_t port_index=0; port_index < num_ports; port_index++) {
      auto &port = context.get_port(port_index);
      if (port.is_read())
        continue;

      if (any_conds)
        out << " &&\n        ";
      else
        out << "    if (";

      out << formatv("stage{0}_port{1}_stream.empty()",
                     thread_id, port_index);
      any_conds = true;
    }

    if (any_conds)
      out << ")\n      ";
    else
      out << "    ";

    out << formatv("stage_{0}(\n", thread_id);
    bool any_channels = false;
    for (port_index_t port_index=0; port_index < num_ports; port_index++) {
      if (any_channels)
        out << ",\n";
      out << formatv("        stage{0}_port{1}_stream",
                     thread_id, port_index);
      any_channels = true;
    }
    out << ");\n";

    // Poll the output ports.
    for (port_index_t port_index=0; port_index < num_ports; port_index++) {
      auto &port = context.get_port(port_index);
      if (port.is_read())
        continue;
      auto &channel = m_setup_func.get_channel_info(port.channel_index());
      auto elem_size = channel.get_elem_size();
      auto sideband_size = channel.get_sideband_size();
      auto sideband_signals_size = channel.get_sideband_signals_size();
      auto flatten = (sideband_size != 0 || sideband_signals_size != 0);

      out << formatv("    if (!stage{0}_port{1}_stream.empty())",
                     thread_id, port_index) << " {\n"
          << formatv("      if (nanotube_channel_has_space("
                     "context, {0}))",
                     port.channel_index()) << " {\n"
          << formatv("        active = true;\n")
          << formatv("        stage{0}_port{1}_stream.read("
                     "stage{0}_port{1}_buffer);\n",
                     thread_id, port_index);
      if (flatten) {
        out << formatv("        nanotube_memcpy(&(stage{0}_port{1}_flat.data[0]), "
                       "&stage{0}_port{1}_buffer.data, {2});\n",
                       thread_id, port_index, elem_size - sideband_size - sideband_signals_size);
        if (sideband_size) {
          out << formatv("        nanotube_memcpy(&(stage{0}_port{1}_flat.data[{2}]), "
                         "&stage{0}_port{1}_buffer.user, {3});\n",
                         thread_id, port_index, elem_size - sideband_size - sideband_signals_size,
                         sideband_size);
        }
        if (sideband_signals_size) {
          out << formatv("        nanotube_memcpy(&(stage{0}_port{1}_flat.data[{2}]), "
                         "&stage{0}_port{1}_buffer.keep, {3});\n",
                         thread_id, port_index, elem_size - sideband_signals_size,
                         (sideband_signals_size-1)/2);
          out << formatv("        nanotube_memcpy(&(stage{0}_port{1}_flat.data[{2}]), "
                         "&stage{0}_port{1}_buffer.strb, {3});\n",
                         thread_id, port_index, elem_size - (sideband_signals_size-1)/2 - 1,
                         (sideband_signals_size-1)/2);
          out << formatv("        stage{0}_port{1}_flat.data[{2}] = "
                         "stage{0}_port{1}_buffer.last;\n",
                         thread_id, port_index, elem_size - 1);
        }
      }
      out << formatv("        nanotube_channel_write(context, {0}, "
                     "&stage{1}_port{2}_{3}, {4});\n",
                     port.channel_index(), thread_id, port_index, 
                     flatten ? "flat" : "buffer", elem_size)
          << "      }\n"
          << "    }\n";
    }
  }

  // Write the end of the thread function and start of the setup
  // function.
  out << (
    "\n"
    "    if (!active)\n"
    "      nanotube_thread_wait();\n"
    "  }\n"
    "}\n"
    "\n"
    "extern \"C\"\n"
    "void nanotube_setup()\n"
    "{\n"
    "  nanotube_context *context = nanotube_context_create();\n"
  );

  out << formatv("  nanotube_channel_t *channels[{0}];\n\n",
                 m_setup_func.channels().size());

  // Create the channels.
  channel_index_t num_channels = m_setup_func.channels().size();
  for (channel_index_t channel_index=0; channel_index<num_channels;
       channel_index++) {
    const channel_info &channel = m_setup_func.channels()[channel_index];
    out << formatv("  channels[{0}] = nanotube_channel_create(",
                   channel_index);
    output_c_string(out, channel.get_name());
    out << formatv(", {0}, {1});\n",
                   channel.get_elem_size(),
                   channel.get_num_elem());
    
    if (channel.get_sideband_size() != 0) {
      out << formatv("  nanotube_channel_set_attr(channels[{0}], "
                     "NANOTUBE_CHANNEL_ATTR_SIDEBAND_BYTES, {1});\n",
                     channel_index, channel.get_sideband_size());
    }
    if (channel.get_sideband_signals_size() != 0) {
      out << formatv("  nanotube_channel_set_attr(channels[{0}], "
                     "NANOTUBE_CHANNEL_ATTR_SIDEBAND_SIGNALS, {1});\n",
                     channel_index, channel.get_sideband_signals_size());
    }
    if (!channel.get_input_interface().empty()) {
      out << formatv("  nanotube_channel_set_attr_str(channels[{0}], "
                     "NANOTUBE_CHANNEL_ATTR_INPUT_INTERFACE, ",
                     channel_index);
      output_c_string(out, channel.get_input_interface());
      out << formatv(");\n");
    }
    if (!channel.get_output_interface().empty()) {
      out << formatv("  nanotube_channel_set_attr_str(channels[{0}], "
                     "NANOTUBE_CHANNEL_ATTR_OUTPUT_INTERFACE, ",
                     channel_index);
      output_c_string(out, channel.get_output_interface());
      out << formatv(");\n");
    }

    out << formatv("  nanotube_context_add_channel(context, {0},"
                   " channels[{0}],", channel_index);
    bool any_flags = false;
    if (channel.has_reader()) {
      out << " NANOTUBE_CHANNEL_READ";
      any_flags = true;
    }
    if (channel.has_writer()) {
        out << (any_flags ? " | " : " ");
      out << "NANOTUBE_CHANNEL_WRITE";
      any_flags = true;
    }
    if (!any_flags)
      out << '0';
    out << ");\n";

    auto read_type = channel.get_read_export_type();
    if (read_type != NANOTUBE_CHANNEL_TYPE_NONE) {
      out << formatv("  nanotube_channel_export(channels[{0}], "
                     "{1}, {2});\n", channel_index,
                     get_enum_name(read_type),
                     "NANOTUBE_CHANNEL_READ");
    }

    auto write_type = channel.get_write_export_type();
    if (write_type != NANOTUBE_CHANNEL_TYPE_NONE) {
      out << formatv("  nanotube_channel_export(channels[{0}], "
                     "{1}, {2});\n", channel_index,
                     get_enum_name(write_type),
                     "NANOTUBE_CHANNEL_WRITE");
    }

    out << "\n";
  }

  // Write the file footer.
  out << (
    "  nanotube_thread_create(context, \"poll_thread\", poll_thread,"
    " nullptr, 0);\n"
    "}\n"
  );
}

void top_writer::output_prototype(raw_os_ostream &out,
                                  thread_id_t thread_id)
{
  // Emit the signature "void stage_<thread_id>(...)" with one hls::stream
  // reference parameter "port<K>" per port of the thread's context.  No
  // terminating newline or semicolon is written.
  out << formatv("void stage_{0}(", thread_id);

  assert(thread_id < m_setup_func.threads().size());
  thread_info &thread = m_setup_func.get_thread_info(thread_id);
  context_info &ctx = m_setup_func.get_context_info(thread.context_index());

  port_index_t port_count = ctx.ports().size();
  for (port_index_t idx = 0; idx < port_count; idx++) {
    stage_port &port = ctx.get_port(idx);
    auto &chan = m_setup_func.get_channel_info(port.channel_index());
    auto total_bytes = chan.get_elem_size();
    auto user_bytes = chan.get_sideband_size();
    auto signal_bytes = chan.get_sideband_signals_size();
    const char *separator = (idx == 0) ? "" : ",";

    // Channels without sideband data use a plain byte-array element.
    if (user_bytes == 0 && signal_bytes == 0) {
      out << formatv("{0}\n  hls::stream<bytes<{1}> > &port{2}",
                     separator, total_bytes, idx);
      continue;
    }

    // Channels with sideband data use an AXI stream element whose data
    // and user widths are given in bits.
    out << formatv("{0}\n  hls::stream<ap_axiu<{1},{2},0,0> > &port{3}",
                   separator,
                   (total_bytes - user_bytes - signal_bytes) << 3,
                   user_bytes << 3,
                   idx);
  }

  // An empty parameter list is spelled "(void)".
  out << (port_count == 0 ? "void)" : ")");
}

// Write STR to OUT as a double-quoted C/C++ string literal whose value is
// exactly the bytes of STR.
//
// Quotes, backslashes, newlines and '?' (to avoid trigraphs) get named
// escapes; other non-printable bytes are written as "\xNN".  Because a
// hex escape absorbs every hex digit that follows it, a printable
// hex-digit character directly after such an escape is placed in a new,
// adjacent literal ("\x01""a"), which the compiler concatenates.
void top_writer::output_c_string(raw_os_ostream &out, StringRef str)
{
  out << '"';
  // True when the last thing written was a "\xNN" escape.
  bool after_hex_escape = false;
  for (char c: str) {
    // isprint()/isxdigit() have undefined behaviour for negative values
    // other than EOF, so classify the byte as an unsigned char.
    unsigned char uc = static_cast<unsigned char>(c);
    bool wrote_hex_escape = false;
    switch(c) {
    case '"':
      out << "\\" "\"";
      break;

    case '\\':
      out << "\\" "\\";
      break;

    case '\n':
      out << "\\n";
      break;

    case '?':
      out << "\\?";
      break;

    default:
      if (isprint(uc)) {
        // Close and reopen the literal so this character is not read
        // as part of the preceding hex escape.
        if (after_hex_escape && isxdigit(uc))
          out << "\"\"";
        out << c;
      } else {
        out << formatv("\\x{0,0+2:x-}", uc);
        wrote_hex_escape = true;
      }
      break;
    }
    after_hex_escape = wrote_hex_escape;
  }
  out << '"';
}

///////////////////////////////////////////////////////////////////////////

global_var_writer::global_var_writer(const DataLayout *dl,
                                     setup_func *setup,
                                     raw_os_ostream *out)
  : m_data_layout(dl)
  , m_setup_func(setup)
  , m_out(out)
{
  // The writer simply records the data layout, setup function and output
  // stream it was given; write() does all the work.
}

// Emit "static uint8_t s<id>[<store size>] = { ... };" holding the
// contents of global variable VAR as recorded in the setup function's
// memory model.
//
// For each byte of the variable's allocation:
//  - bytes written by an integer store are extracted from that integer
//    according to the data layout's endianness,
//  - bytes written by a memset take the memset value,
//  - any other byte is emitted as 0xff.
// The initializer is written 16 bytes per line.
void global_var_writer::write(value_id_t id, GlobalVariable *var)
{
  LLVM_DEBUG(dbgv("Writing state s{0} for {1}\n", id, *var););

  auto *ty = var->getType()->getElementType();
  *m_out << "static uint8_t s" << id << "["
         << m_data_layout->getTypeStoreSize(ty) << "] = {";

  bool is_big_endian = m_data_layout->isBigEndian();
  auto alloc = m_setup_func->find_alloc_of_var(var);
  auto base_ptr = m_setup_func->get_alloc_base(alloc);
  auto end_ptr = m_setup_func->get_alloc_end_of_ptr(base_ptr);

  LLVM_DEBUG(dbgv("Writing range {0}..{1}.\n", base_ptr, end_ptr););

  // Number of initializer bytes placed on each output line.
  const uint64_t bytes_per_line = 16;
  uint64_t byte_index = 0;

  auto end_it = m_setup_func->memory_at(end_ptr);
  for (auto it = m_setup_func->memory_at(base_ptr); it!=end_it; ++it) {
    uint8_t byte_val = 0xff;
    if (it->write_value.is_int()) {
      // Select the byte at write_offset within the stored integer.
      auto write_value = it->write_value.get_int();
      auto bit_width = write_value.getBitWidth();
      uint64_t shift = (is_big_endian
                        ? bit_width - 8*(it->write_offset+1)
                        : 8*it->write_offset);
      // The final byte of an integer whose width is not a multiple of
      // eight may contain fewer than eight bits.
      uint64_t width = std::min(uint64_t(8), bit_width-shift);
      byte_val = write_value.extractBits(width, shift).getLimitedValue();

    } else if (it->write_value.is_memset()) {
      byte_val = it->write_value.get_memset();
    }

    LLVM_DEBUG(dbgv("  Value {0:x+} offset {1} is {2:x+2}.\n",
                    it->write_value, it->write_offset, byte_val););

    // Start a new line before the first byte and after every
    // bytes_per_line bytes.
    bool new_line = (byte_index++ % bytes_per_line) == 0;
    *m_out << (new_line ? "\n  0x" : " 0x");
    *m_out << formatv("{0:x-2},", byte_val);
  }
  *m_out << "\n};\n";
}

///////////////////////////////////////////////////////////////////////////

stage_writer::stage_writer(top_writer &top,
                           context_info &context,
                           thread_id_t thread_id,
                           std::ostream &out)
  : m_top(top)
  , m_setup_func(top.get_setup_func())
  , m_thread_id(thread_id)
  , m_out(out)
  , m_args(m_setup_func.get_thread_info(thread_id).args())
  , m_context(context)
  , m_data_layout(m_args.func->getParent()->getDataLayout())
  , m_entry_bb(nullptr)
{
  Function &thread_func = *m_args.func;

  // Check that the thread function meets the HLS constraints, then
  // remember its entry block (allocas are only accepted there).
  validate_hls_thread_function(thread_func);
  m_entry_bb = &thread_func.getEntryBlock();
}

// Write the complete source for this stage: file header and includes,
// the static arrays for referenced global variables, then the stage
// function.  The order matters: write_state() assigns the static variable
// IDs that write_function() relies on (write_pragmas() iterates over
// m_static_var_ids), and the arrays must be declared before the function
// that uses them.
void stage_writer::write()
{
  write_preamble();
  write_state();
  write_function();
}

void stage_writer::write_preamble()
{
  // Output the file header.
  m_out << formatv("// Stage {0}\n", m_thread_id);
  m_out << formatv("// Thread name:    {0}\n", m_args.name);
  m_out << formatv("// Thread function {0}\n", m_args.func->getName());

  port_index_t num_ports = m_context.ports().size();
  for (port_index_t port_index = 0; port_index < num_ports; port_index++) {
    const stage_port &port = m_context.get_port(port_index);
    m_out << formatv("//   Port {0} {1} channel {2}\n",
                     port_index, (port.is_read() ? "reads " : "writes"),
                     port.channel_index());
  }

  // Include some header files.
  m_out << ( "#include \"ap_int.h\"\n"
             "#include \"hls_stream.h\"\n"
             "#include \"stages.hh\"\n"
             "#include <cassert>\n"
             "#include <cstdint>\n"
             "#include <cstring>\n"
             "\n" );
}

void stage_writer::write_state()
{
  Function *func = m_args.func;
  for (BasicBlock &bb: *func) {
    for (Instruction &insn: bb) {
      auto iid = get_intrinsic(&insn);
      if (intrinsic_is_nop(iid))
        continue;
      for (Value *operand: insn.operands()) {
        if (operand->getType()->isPointerTy())
          write_state_for_operand(operand);
      }
    }
  }
  if (!m_static_var_ids.empty())
    m_out << "\n";
}

void stage_writer::write_state_for_operand(Value *pointer)
{
  // Only pointers that are a global variable plus constant in-bounds
  // offsets need state.
  Value *base = pointer->stripInBoundsConstantOffsets();
  auto *global = dyn_cast<GlobalVariable>(base);
  if (!global)
    return;

  // The first reference to a global gives it the next sequential ID,
  // which names its array "s<id>".  Later references write nothing.
  auto next_id = m_static_var_ids.size();
  auto result = m_static_var_ids.insert(std::make_pair(base, next_id));
  bool newly_added = result.second;
  if (newly_added) {
    global_var_writer writer(&m_data_layout, &m_setup_func, &m_out);
    writer.write(result.first->second, global);
  }
}

void stage_writer::write_function()
{
  // Signature, then the opening brace on a line of its own.
  m_top.output_prototype(m_out, m_thread_id);
  m_out << "\n{\n";

  // Function body: local declarations, HLS pragmas, translated code.
  write_declarations();
  write_pragmas();
  write_body();

  m_out << "}\n";
}

void stage_writer::write_decl_type(Type *type, unsigned depth)
{
  // Write the HLS C++ type used to declare a variable of LLVM type TYPE,
  // followed by a space where needed.  DEPTH is the struct nesting level;
  // structs may not be nested.
  switch (type->getTypeID()) {
  case Type::PointerTyID:
    // Every pointer is represented as a byte pointer.
    m_out << "uint8_t *";
    return;

  case Type::IntegerTyID:
    m_out << formatv("ap_uint<{0}> ",
                     cast<llvm::IntegerType>(type)->getBitWidth());
    return;

  case Type::StructTyID: {
    if (depth != 0) {
      report_fatal_errorv("Encountered nested struct type: {0}",
                          *type);
    }

    // An anonymous struct whose fields are named m0, m1, ...
    auto *struct_ty = cast<StructType>(type);
    m_out << "struct { ";
    for (unsigned idx = 0, n = struct_ty->getNumElements(); idx != n; ++idx) {
      write_decl_type(struct_ty->getElementType(idx), depth + 1);
      m_out << formatv("m{0}; ", idx);
    }
    m_out << "}";
    return;
  }

  default:
    break;
  }

  report_fatal_errorv("Unsupported expression type {0}", *type);
}

static uint64_t get_alloca_size(const DataLayout &dl,
                                const AllocaInst &alloca)
{
  Type *elem_type = alloca.getAllocatedType();
  uint64_t elem_size = dl.getTypeAllocSize(elem_type);
  const Value *elem_count_val = alloca.getArraySize();

  auto *elem_count_ci = dyn_cast<ConstantInt>(elem_count_val);
  if ( elem_count_ci == nullptr )
    report_fatal_errorv("Non-constant element count in alloca: {0}",
                        alloca);

  return elem_count_ci->getLimitedValue() * elem_size;
}

void stage_writer::write_decl_for_alloca(const AllocaInst *alloca)
{
  // Allocas become fixed local arrays, so only entry-block allocas are
  // accepted.
  if (!m_is_entry_bb)
    report_fatal_errorv("Alloca outside the entry block {0}",
                        *alloca);

  uint64_t num_bytes = get_alloca_size(m_data_layout, *alloca);
  auto align = alloca->getAlignment();

  LLVM_DEBUG(m_out << formatv("  // {0}\n", *alloca););

  // "uint8_t v<id>[<size>]", with an alignment attribute when the alloca
  // requires more than byte alignment, then a full partition pragma.
  m_out << formatv("  uint8_t v{0}[{1}]", m_next_value_id, num_bytes);
  if (align > 1)
    m_out << formatv("__attribute__((aligned({0})))", align);
  m_out << ";\n"
        << formatv("#pragma HLS array_partition variable=v{0}"
                   " complete\n", m_next_value_id);
}

void stage_writer::write_decl_for_insn(const Instruction *insn)
{
  Type *type = insn->getType();

  // Ignore void instructions.
  if ( type->isVoidTy() )
    return;

  auto opcode = insn->getOpcode();
  if (opcode == Instruction::Alloca) {
    write_decl_for_alloca(cast<AllocaInst>(insn));
  } else {
    auto iid = get_intrinsic(insn);
    if (iid == Intrinsics::llvm_stacksave)
      return;

    // Ignore pointers which will be unwrapped by write_operand.
    if (type->isPointerTy()) {
      const Value *base = insn->stripInBoundsConstantOffsets();
      if (base != insn)
        return;
    }

    m_out << "  ";
    write_decl_type(type);
    m_out << formatv("v{0};\n", m_next_value_id);
  }

  m_local_var_ids.insert(std::make_pair(insn, m_next_value_id));
  m_next_value_id++;
}

void stage_writer::write_declarations()
{
  // Declare a variable for each port.
  port_index_t num_ports = m_context.ports().size();
  for (port_index_t port_index=0; port_index < num_ports; port_index++) {
    stage_port &port = m_context.get_port(port_index);
    auto &channel = m_setup_func.get_channel_info(port.channel_index());
    auto elem_size = channel.get_elem_size();
    auto sideband_size = channel.get_sideband_size();
    auto sideband_signals_size = channel.get_sideband_signals_size();

    if (sideband_size != 0 || sideband_signals_size != 0) {
      m_out << formatv("  ap_axiu<{0},{1},0,0> port{2}_data;\n",
                       (elem_size - sideband_size - sideband_signals_size) << 3, sideband_size << 3,
                       port_index);
      /* NANO-411: This removes the warning in cosim but prevents compilation for hw,
       * so backing out for now */
      /*m_out << formatv("#pragma HLS array_partition variable=port{0}_data complete\n",
                       port_index);*/
    }
    else {
      m_out << formatv("  bytes<{0}> port{1}_data;\n",
                       elem_size, port_index);
    }
  }

  // Declare a variable for every instruction which produces a value.
  for ( const BasicBlock &bb: *m_args.func ) {
    m_is_entry_bb = (&bb == m_entry_bb);
    for ( const Instruction &insn: bb ) {
      write_decl_for_insn(&insn);
    }
  }
}

void stage_writer::write_pragmas()
{
  m_out << ( "\n"
             "#pragma HLS pipeline II=1\n");
  m_out << formatv("#pragma HLS interface ap_ctrl_none port=return\n");

  port_index_t num_ports = m_context.ports().size();
  for (port_index_t port_index=0; port_index < num_ports; port_index++) {
    // Find the port ID and stage_port.
    const stage_port &port = m_context.get_port(port_index);
    setup_func &setup = m_top.get_setup_func();
    channel_info &channel = setup.get_channel_info(port.channel_index());
    auto sideband_size = channel.get_sideband_size();
    auto sideband_signals_size = channel.get_sideband_signals_size();

    m_out << formatv("#pragma HLS interface axis port=port{0}\n",
                     port_index);
    /* Aggregate only appropriate if no axis sideband signals */
    if (sideband_size == 0 && sideband_signals_size == 0) {
      m_out << "#if defined(NANOTUBE_USING_VIVADO_HLS)\n";
      m_out << formatv("#pragma HLS data_pack variable=port{0}\n",
                       port_index);
      m_out << "#else // defined(NANOTUBE_USING_VIVADO_HLS)\n";
      m_out << formatv("#pragma HLS aggregate variable=port{0}\n",
                       port_index);
      m_out << formatv("#pragma HLS aggregate variable=port{0}_data\n",
                       port_index);
      m_out << "#endif // defined(NANOTUBE_USING_VIVADO_HLS)\n";
    }
  }

  // Write pragmas for static variables.
  for (size_t index=0; index<m_static_var_ids.size(); index++) {
    m_out << formatv("#pragma HLS array_partition variable=s{0}"
                     " complete\n", index);
  }
}

// Emit the executable statements of the thread function.
//
// Basic blocks are visited in function order.  A block that has uses gets
// a label "L<n>:", whose ID comes from m_label_ids (allocating the next ID
// if the block has none yet).  Each instruction is then translated
// according to its opcode; any opcode without a case below is a fatal
// error.  The variables written here were declared earlier by
// write_declarations() (see write_function()).
void stage_writer::write_body()
{
  for ( const BasicBlock &bb: *m_args.func ) {
    m_out << "\n";

    // Output the label.
    if ( !bb.use_empty() ) {
      // Allocate an ID if there is not already one for this basic block.
      auto ins = m_label_ids.insert(
        std::make_pair(&bb, m_label_ids.size()));

      m_out << formatv("L{0}:\n", ins.first->second);
    }

    LLVM_DEBUG(
      m_out << "  // ";
      bb.printAsOperand(m_out, false);
      m_out << ":\n";
      );

    // Output the instructions.
    // For binary operators, the optional last argument of write_binary_op
    // selects signed operands (each wrapped in ap_int<W>(...)); calls
    // without it use the default declared with write_binary_op.
    for ( const Instruction &insn: bb ) {
      LLVM_DEBUG(
        m_out << formatv("  // {0}\n", insn);
        );
      auto opcode = insn.getOpcode();
      switch ( opcode ) {
      case Instruction::Alloca:
        // Already handled.
        // (Declared as a local array by write_decl_for_alloca.)
        break;

      case Instruction::Add:
        write_binary_op("+", cast<BinaryOperator>(insn));
        break;

      case Instruction::And:
        write_binary_op("&", cast<BinaryOperator>(insn));
        break;

      case Instruction::AShr:
        // Arithmetic shift: signed operands.
        write_binary_op(">>", cast<BinaryOperator>(insn), true);
        break;

      case Instruction::BitCast:
        // Pointer bitcasts are handled on demand.  Non-pointer
        // bitcasts are not supported.
        if (!insn.getType()->isPointerTy())
          report_fatal_errorv("Unsupported non-pointer bitcast in"
                              " thread body: {0}", insn);
        break;

      case Instruction::Br:
        write_branch(cast<BranchInst>(insn));
        break;

      case Instruction::Call:
        write_call(cast<CallBase>(insn));
        break;

      case Instruction::ExtractValue:
        write_extract_value(cast<ExtractValueInst>(insn));
        break;

      case Instruction::GetElementPtr:
        write_gep(cast<GetElementPtrInst>(insn));
        break;

      case Instruction::ICmp:
        write_icmp(cast<ICmpInst>(insn));
        break;

      case Instruction::Load:
        write_load(cast<LoadInst>(insn));
        break;

      case Instruction::Mul:
        write_binary_op("*", cast<BinaryOperator>(insn));
        break;

      case Instruction::LShr:
        write_binary_op(">>", cast<BinaryOperator>(insn));
        break;

      case Instruction::Or:
        write_binary_op("|", cast<BinaryOperator>(insn));
        break;

      case Instruction::PHI:
        // No code here.  NOTE(review): presumably the PHI variable is
        // assigned on each incoming edge by write_cfg_edge — confirm.
        break;

      case Instruction::Ret:
        // The stage function returns void.
        m_out << "  return;\n";
        break;

      case Instruction::SDiv:
        write_binary_op("/", cast<BinaryOperator>(insn), true);
        break;

      case Instruction::Select:
        write_select(cast<SelectInst>(insn));
        break;

      case Instruction::SExt:
        // Sign-extend: cast the source through ap_int<src width>.
        write_cast(cast<CastInst>(insn), true);
        break;

      case Instruction::Shl:
        write_binary_op("<<", cast<BinaryOperator>(insn));
        break;

      case Instruction::SRem:
        write_binary_op("%", cast<BinaryOperator>(insn), true);
        break;

      case Instruction::Store:
        write_store(cast<StoreInst>(insn));
        break;

      case Instruction::Sub:
        write_binary_op("-", cast<BinaryOperator>(insn));
        break;

      case Instruction::Switch:
        write_switch(cast<SwitchInst>(insn));
        break;

      case Instruction::Trunc:
        write_cast(cast<CastInst>(insn), false);
        break;

      case Instruction::UDiv:
        write_binary_op("/", cast<BinaryOperator>(insn), false);
        break;

      case Instruction::Unreachable:
        // Flag the impossible path with a failing assert, then leave.
        m_out << "  assert(false);\n"
              << "  return;\n";
        break;

      case Instruction::URem:
        write_binary_op("%", cast<BinaryOperator>(insn), false);
        break;

      case Instruction::Xor:
        write_binary_op("^", cast<BinaryOperator>(insn));
        break;

      case Instruction::ZExt:
        // Zero-extend: cast the source through ap_uint<src width>.
        write_cast(cast<CastInst>(insn), false);
        break;

      default:
        report_fatal_errorv("Unsupported instruction in thread body: {0}",
                            insn);
      }
    }
  }
}

void stage_writer::write_binary_op(const std::string &op,
                                   const BinaryOperator &insn,
                                   bool is_signed)
{
  m_out << "  ";
  write_operand(insn, insn);
  m_out << " = ";
  const Value &op0 = *insn.getOperand(0);
  if (is_signed)
    m_out << formatv("ap_int<{0}>(",
                     op0.getType()->getIntegerBitWidth());
  write_operand(insn, op0);
  if (is_signed)
    m_out << ')';
  m_out << ' ' << op << ' ';
  const Value &op1 = *insn.getOperand(1);
  if (is_signed)
    m_out << formatv("ap_int<{0}>(",
                     op1.getType()->getIntegerBitWidth());
  write_operand(insn, op1);
  if (is_signed)
    m_out << ')';
  m_out << ";\n";
}

void stage_writer::write_branch(const BranchInst &insn)
{
  if ( insn.isUnconditional() ) {
    write_cfg_edge("  ", *insn.getParent(), *insn.getSuccessor(0));

  } else {
    m_out << "  if ( ";
    write_operand(insn, *insn.getCondition());
    m_out << " ) {\n";
    write_cfg_edge("    ", *insn.getParent(), *insn.getSuccessor(0));
    m_out << "  } else {\n";
    write_cfg_edge("    ", *insn.getParent(), *insn.getSuccessor(1));
    m_out << "  }\n";
  }
}

// Emit code for a call to llvm.bswap.
//
// The result is built as an OR of one term per byte: the byte at bit
// offset i is shifted to bit offset j = width-8-i and masked with
// (0xff << j).
void stage_writer::write_bswap(const CallBase &insn)
{
  assert(insn.getNumArgOperands() == 1);

  const Value *val = insn.getArgOperand(0);
  unsigned int val_width = val->getType()->getIntegerBitWidth();

  // llvm.bswap is only defined for integers with an even number of bytes,
  // i.e. widths that are a multiple of 16 bits.  The expression generated
  // below is correct for every such width (not just i16 and i32).
  if (val_width % 16 != 0)
    report_fatal_errorv("write_bswap: Unsupported bswap width: {0}", val_width);

  m_out <<"  ";
  write_operand(insn, insn, HLS_TYPE_INTEGER);
  m_out << " = ( ";
  for (unsigned int i=0; i<val_width; i+=8) {
    // Destination bit offset of the byte currently at bit offset i.
    unsigned int j = (val_width-8)-i;

    if (i != 0)
      m_out << " |\n    ";

    m_out << "( (";
    write_operand(insn, *val, HLS_TYPE_INTEGER);

    if (i <= j)
      m_out << " << " << (j-i);
    else
      m_out << " >> " << (i-j);

    m_out << ") & (ap_int<" << val_width << ">(0xff) << "
          << j << ") )";
  }
  m_out << " );\n";
}

void stage_writer::write_call(const CallBase &insn)
{
  // Dispatch on the intrinsic being called.  Anything not handled below
  // is accepted only if it is a known no-op.
  const auto iid = get_intrinsic(&insn);
  switch (iid) {
  // Channel accesses.
  case Intrinsics::channel_try_read:
    write_channel_call(insn, /*is_read=*/true);
    return;
  case Intrinsics::channel_write:
    write_channel_call(insn, /*is_read=*/false);
    return;

  // Debugging and tracing.
  case Intrinsics::debug_trace:
    write_debug_trace_call(insn);
    return;
  case Intrinsics::trace_buffer:
    write_trace_buffer_call(insn);
    return;

  // LLVM intrinsics with a direct translation.
  case Intrinsics::llvm_bswap:
    write_bswap(insn);
    return;
  case Intrinsics::llvm_memcpy:
    write_memcpy(insn);
    return;
  case Intrinsics::llvm_memset:
    write_memset(insn);
    return;
  case Intrinsics::llvm_memcmp:
    write_memcmp(insn);
    return;
  case Intrinsics::llvm_usub_sat:
    write_usub_sat(insn);
    return;
  case Intrinsics::llvm_usub_with_overflow:
    write_usub_with_overflow(insn);
    return;

  // Calls that produce no code.  Stack save/restore can be dropped
  // because all allocas are in the entry block.
  case Intrinsics::thread_wait:
  case Intrinsics::llvm_stackrestore:
  case Intrinsics::llvm_stacksave:
    return;

  default:
    break;
  }

  if (!intrinsic_is_nop(iid))
    report_fatal_errorv("Unsupported call in thread function {0}",
                        insn);
}

void stage_writer::write_cast(const CastInst &insn, bool is_signed)
{
  m_out << "  ";
  write_operand(insn, insn);
  m_out << " = ";
  const Value *operand = insn.getOperand(0);
  m_out << formatv("{0}<{1}>(",
                   (is_signed ? "ap_int" : "ap_uint"),
                   operand->getType()->getIntegerBitWidth());
  write_operand(insn, *operand);
  m_out << ");\n";
}

// Emit the HLS code for a Nanotube channel access on this stage's port.
//
// is_read == true (channel_try_read): port<N>.read_nb() fills port<N>_data.
// Its result is assigned to the call's variable, and the fields are then
// copied into the caller's buffer.
// is_read == false (channel_write): port<N>_data is filled from the
// caller's buffer and written to the stream.
//
// A fatal error is reported if:
//  - the context argument is not argument 0 of the thread function,
//  - the channel is not a port of this context in the requested direction,
//  - the access size differs from the channel's element size.
void
stage_writer::write_channel_call(const CallBase &insn, bool is_read)
{
  auto args = channel_read_write_args(&insn);
  auto *callee = insn.getCalledFunction();

  // Make sure the context argument is the function argument.  The
  // diagnostic prints the IR value itself, not the address of the Value.
  auto *context_arg = dyn_cast<Argument>(args.context);
  if (context_arg == nullptr || context_arg->getArgNo() != 0)
    report_fatal_errorv("Invalid context argument to {0} : {1}",
                        callee->getName(), *args.context);

  port_index_t port_index = m_context.get_port_index(args.channel_id, is_read);
  if (port_index == port_index_none)
    report_fatal_errorv("Failed to find channel ID {0} in context",
                        args.channel_id);

  // Find the port ID and stage_port.
  const stage_port &port = m_context.get_port(port_index);

  // Check the data length.
  setup_func &setup = m_top.get_setup_func();
  channel_info &channel = setup.get_channel_info(port.channel_index());
  auto elem_size = channel.get_elem_size();
  auto sideband_size = channel.get_sideband_size();
  auto sideband_signals_size = channel.get_sideband_signals_size();
  auto user_offset = elem_size - sideband_size - sideband_signals_size;
  auto keep_offset = elem_size - sideband_signals_size;
  auto strb_offset = elem_size - (sideband_signals_size - 1)/2 - 1;
  auto last_offset = elem_size - 1;

  if (elem_size != args.data_size)
    report_fatal_errorv("Data size mismatch ({0} != {1}) in channel {2}: {3}",
                        elem_size, args.data_size,
                        is_read ? "read" : "write", insn);

  // A channel read will write to the buffer.  A channel write will
  // read from the buffer.
  bool buf_is_written = is_read;
  check_mem_access(&insn, args.data, args.data_size, buf_is_written);

  /* This converts between the ap_axiu structured data type and our
   * internal flat byte array as follows:
   *
   * Start byte  | Length        | Field
   * -----------------------------------
   * 0           | data width       | Data
   * user_offset | sideband_size    | Sideband bytes (TUSER)
   * keep_offset | (data_width) / 8 | Sideband signal TKEEP
   * strb_offset | (data_width) / 8 | Sideband signal TSTRB
   * last_offset | 1                | Sideband signal TLAST
   *
   * If sideband_size is zero, TUSER is omitted, and the other fields are moved up
   * If sideband_signals_size is zero, TKEEP/TSTRB/TLAST are omitted
   */

  // Write the channel access.
  if (is_read) {
    m_out << "  ";
    write_operand(insn, insn);
    m_out << formatv(" = port{0}.read_nb(port{0}_data);\n",
                     port_index);
    // ap_axiu keeps data in a scalar member (take its address), while
    // bytes<N> data is usable as a pointer directly.
    m_out << "  nanotube_memcpy(";
    write_operand(insn, *(args.data), HLS_TYPE_POINTER);
    if( sideband_size !=0 || sideband_signals_size != 0 )
      m_out << formatv(", &port{0}_data.data, {1});\n",
                       port_index, user_offset);
    else
      m_out << formatv(", port{0}_data.data, {1});\n",
                       port_index, user_offset);

    if (sideband_size != 0) {
      m_out << "  nanotube_memcpy(";
      write_operand(insn, *(args.data), HLS_TYPE_POINTER);
      m_out << formatv("+{0}", user_offset);
      m_out << formatv(", &port{0}_data.user, {1});\n",
                       port_index, sideband_size);
    }

    if (sideband_signals_size != 0) {
      m_out << "  nanotube_memcpy(";
      write_operand(insn, *(args.data), HLS_TYPE_POINTER);
      m_out << formatv("+{0}", keep_offset);
      m_out << formatv(", &port{0}_data.keep, {1});\n",
                       port_index, (sideband_signals_size-1)/2);

      m_out << "  nanotube_memcpy(";
      write_operand(insn, *(args.data), HLS_TYPE_POINTER);
      m_out << formatv("+{0}", strb_offset);
      m_out << formatv(", &port{0}_data.strb, {1});\n",
                       port_index, (sideband_signals_size-1)/2);

      /* last is just one bit, so do copy manually */
      m_out << "  (";
      write_operand(insn, *(args.data), HLS_TYPE_POINTER);
      m_out << formatv("+{0})[0]", last_offset);
      m_out << formatv(" = port{0}_data.last;\n", port_index);
    }
  } else {
    if (sideband_size != 0 || sideband_signals_size != 0)
      m_out << formatv("  nanotube_memcpy(&port{0}_data.data, ",
                       port_index);
    else
      m_out << formatv("  nanotube_memcpy(port{0}_data.data, ",
                       port_index);
    write_operand(insn, *(args.data), HLS_TYPE_POINTER);
    m_out << formatv(", {0});\n", user_offset);

    if (sideband_size != 0) {
      m_out << formatv("  nanotube_memcpy(&port{0}_data.user, ",
                       port_index);
      write_operand(insn, *(args.data), HLS_TYPE_POINTER);
      m_out << formatv("+{0}", user_offset);
      m_out << formatv(", {0});\n", sideband_size);
    }

    if (sideband_signals_size != 0) {
      m_out << formatv("  nanotube_memcpy(&port{0}_data.keep, ",
                       port_index);
      write_operand(insn, *(args.data), HLS_TYPE_POINTER);
      m_out << formatv("+{0}", keep_offset);
      m_out << formatv(", {0});\n", (sideband_signals_size-1)/2);

      m_out << formatv("  nanotube_memcpy(&port{0}_data.strb, ",
                       port_index);
      write_operand(insn, *(args.data), HLS_TYPE_POINTER);
      m_out << formatv("+{0}", strb_offset);
      m_out << formatv(", {0});\n", (sideband_signals_size-1)/2);

      /* last is just one bit, so do nanotube_memcpy() manually */
      m_out << formatv("  port{0}_data.last = (", port_index);
      write_operand(insn, *(args.data), HLS_TYPE_POINTER);
      m_out << formatv("+{0})[0] & 1;\n", last_offset);
    }

    m_out << formatv("  port{0}.write(port{0}_data);\n", port_index);
  }
}

void
stage_writer::write_debug_trace_call(const CallBase &insn)
{
  auto args = debug_trace_args(&insn);
  m_out << "  nanotube_debug_trace(";
  write_operand(insn, *args.id);
  m_out << ", ";
  write_operand(insn, *args.value);
  m_out << ");\n";
}

void
stage_writer::write_extract_value(const ExtractValueInst &insn)
{
  auto *aggr = insn.getAggregateOperand();
  unsigned num_idx = insn.getNumIndices();

  m_out << "  ";
  write_operand(insn, insn);
  m_out << " = ";
  write_operand(insn, *aggr, HLS_TYPE_STRUCT);
  for (unsigned i = 0; i < num_idx; i++) {
    unsigned idx = insn.getIndices()[i];
    m_out << formatv(".m{0}", idx);
  }
  m_out << ";\n";
}

void
stage_writer::write_trace_buffer_call(const CallBase &insn)
{
  assert(insn.getNumArgOperands() == 3);
  const Value *id = insn.getArgOperand(0);
  const Value *buffer = insn.getArgOperand(1);
  const Value *size = insn.getArgOperand(2);

  check_mem_access(&insn, buffer, size, false);

  m_out << "  nanotube_trace_buffer(";
  write_operand(insn, *id);
  m_out << ", ";
  write_operand(insn, *buffer, HLS_TYPE_POINTER);
  m_out << ", ";
  write_operand(insn, *size);
  m_out << ");\n";
}

void stage_writer::write_gep(const GetElementPtrInst &insn)
{
  // GEP instructions with all constant indices are handled on demand.
  if (insn.isInBounds() && insn.hasAllConstantIndices())
    return;

  LLVM_DEBUG(dbgs() << formatv("Handling GEP: {0}\n", insn));

  // Start the assignment.
  m_out << "  ";
  write_operand(insn, insn, HLS_TYPE_POINTER);
  m_out << " = (";
  write_operand(insn, *insn.getPointerOperand(), HLS_TYPE_POINTER);

  // Add terms for each of the indices.
  for (auto GTI=llvm::gep_type_begin(insn), GTE=llvm::gep_type_end(insn);
       GTI!=GTE; GTI++) {
    // The index value.
    const Value *val = GTI.getOperand();
    auto *const_int = dyn_cast<ConstantInt>(val);

    // The meaning of the index depends on whether this is a struct
    // type or an index type.
    if (StructType *struct_type = GTI.getStructTypeOrNull()) {
      if (const_int == nullptr)
        report_fatal_errorv("Struct index is not a constant: {0}",
                            insn);

      // Find the offset of the element being indexed.
      unsigned element_index = const_int->getZExtValue();
      auto *layout = m_data_layout.getStructLayout(struct_type);
      uint64_t offset = layout->getElementOffset(element_index);
      LLVM_DEBUG(
        dbgs() << formatv("  Struct type type={0}, index={1}, offset={2}\n",
                          *struct_type, element_index, offset));

      // Output it if it is not zero.
      if (offset != 0)
        m_out << formatv(" + {0}", offset);

    } else {
      // The type being indexed.
      Type *type = GTI.getIndexedType();
      LLVM_DEBUG(
        dbgs() << formatv("  Index type type={0}, val={1}\n",
                          *type, *val));

      // Multiply the size of the type by the index.
      if (const_int == nullptr || !const_int->isZero()) {
        m_out << formatv(" + ({0} * ",
                         m_data_layout.getTypeAllocSize(type));
        write_operand(insn, *val);
        m_out << ")";
      }
    }
  }

  // End the assignment.
  m_out << ");\n";
}

void stage_writer::write_icmp(const ICmpInst &insn)
{
  // Translate the predicate into a C++ comparison operator plus a flag
  // saying whether the operands are compared as signed values.
  const char *cmp_op = nullptr;
  bool signed_cmp = false;

  auto predicate = insn.getPredicate();
  switch (predicate) {
  case CmpInst::ICMP_EQ:  cmp_op = " == "; break;
  case CmpInst::ICMP_NE:  cmp_op = " != "; break;
  case CmpInst::ICMP_ULE: cmp_op = " <= "; break;
  case CmpInst::ICMP_ULT: cmp_op = " < ";  break;
  case CmpInst::ICMP_UGE: cmp_op = " >= "; break;
  case CmpInst::ICMP_UGT: cmp_op = " > ";  break;
  case CmpInst::ICMP_SLE: cmp_op = " <= "; signed_cmp = true; break;
  case CmpInst::ICMP_SLT: cmp_op = " < ";  signed_cmp = true; break;
  case CmpInst::ICMP_SGE: cmp_op = " >= "; signed_cmp = true; break;
  case CmpInst::ICMP_SGT: cmp_op = " > ";  signed_cmp = true; break;
  default:
    report_fatal_errorv("Unsupported predicate {0} in {1}",
                        CmpInst::getPredicateName(predicate),
                        insn);
    return;
  }

  write_icmp_body(cmp_op, signed_cmp, insn);
}

void stage_writer::write_icmp_body(const std::string &op_str,
                                   bool is_signed,
                                   const ICmpInst &insn)
{
  Type *type = insn.getOperand(0)->getType();
  unsigned int bit_width = type->getIntegerBitWidth();
  m_out << "  ";
  write_operand(insn, insn);
  m_out << " = ";
  if (is_signed)
    m_out << formatv("((ap_uint<{0}>(1) << {1}) ^ ",
                     bit_width, bit_width-1);
  write_operand(insn, *insn.getOperand(0));
  if (is_signed)
    m_out << ")";
  m_out << op_str;
  if (is_signed)
    m_out << formatv("((ap_uint<{0}>(1) << {1}) ^ ",
                     bit_width, bit_width-1);
  write_operand(insn, *insn.getOperand(1));
  if (is_signed)
    m_out << ")";
  m_out << ";\n";
}

void stage_writer::write_load(const LoadInst &insn)
{
  // Only integer loads are supported; they are assembled byte by byte.
  Type *type = insn.getType();
  auto int_ty = dyn_cast<IntegerType>(type);
  if (int_ty == nullptr) {
    report_fatal_errorv("Unsupported type in load instruction: {0}",
                        insn);
  }

  m_out << "  ";
  write_operand(insn, insn);
  m_out << " = ";

  unsigned num_bits = int_ty->getBitWidth();
  unsigned num_bytes = m_data_layout.getTypeStoreSize(type);
  bool is_big_endian = m_data_layout.isBigEndian();
  const Value *ptr = insn.getPointerOperand();
  check_mem_access(&insn, ptr, num_bytes, false);

  // Emit "(ap_uint<N>(p[0]) << s0) | (ap_uint<N>(p[1]) << s1) | ...".
  // Each byte is widened to N bits before shifting so the shift does
  // not truncate it.
  for (unsigned i=0; i<num_bytes; i++) {
    // Memory byte i holds value bits [8*k, 8*k+8), where k = i on a
    // little-endian target.  A big-endian target stores the most
    // significant byte first, so k = num_bytes-1-i.  (The previous
    // "num_bytes-8-shift" mixed bytes with bits and underflowed.)
    unsigned byte_index = is_big_endian ? (num_bytes - 1 - i) : i;
    unsigned shift = byte_index * 8;
    if (i != 0)
      m_out << " |\n    ";
    m_out << "(ap_uint<" << num_bits << ">(";
    write_operand(insn, *ptr, HLS_TYPE_POINTER);
    m_out << "[" << i << "]) << " << shift << ")";
  }
  m_out << ";\n";
}

void stage_writer::write_memcpy(const CallBase &insn)
{
  assert(insn.getNumArgOperands() == 4);
  const Value *dest = insn.getArgOperand(0);
  const Value *src = insn.getArgOperand(1);
  const Value *size = insn.getArgOperand(2);
  const Value *is_volatile = insn.getArgOperand(3);

  const auto *is_volatile_const = dyn_cast<ConstantInt>(is_volatile);
  if (is_volatile_const == nullptr)
    report_fatal_errorv("Invalid expression for is_volatile: {0}", insn);
  if (!is_volatile_const->isZeroValue())
    report_fatal_errorv("Volatile memcpy is not supported: {0}", insn);

  check_mem_access(&insn, dest, size, true);
  check_mem_access(&insn, src, size, false);

  m_out << "  nanotube_memcpy(";
  write_operand(insn, *dest, HLS_TYPE_POINTER);
  m_out << ", ";
  write_operand(insn, *src, HLS_TYPE_POINTER);
  m_out << ", ";
  write_operand(insn, *size, HLS_TYPE_INTEGER);
  m_out << ");\n";
}

void stage_writer::write_memset(const CallBase &insn)
{
  assert(insn.getNumArgOperands() == 4);
  const Value *dest = insn.getArgOperand(0);
  const Value *val = insn.getArgOperand(1);
  const Value *size = insn.getArgOperand(2);
  const Value *is_volatile = insn.getArgOperand(3);

  const auto *is_volatile_const = dyn_cast<ConstantInt>(is_volatile);
  if (is_volatile_const == nullptr)
    report_fatal_errorv("Invalid expression for is_volatile: {0}", insn);
  if (!is_volatile_const->isZeroValue())
    report_fatal_errorv("Volatile memset is not supported: {0}", insn);

  check_mem_access(&insn, dest, size, true);

  m_out << "  memset(";
  write_operand(insn, *dest, HLS_TYPE_POINTER);
  m_out << ", ";
  write_operand(insn, *val, HLS_TYPE_INTEGER);
  m_out << ", ";
  write_operand(insn, *size, HLS_TYPE_INTEGER);
  m_out << ");\n";
}

void stage_writer::write_memcmp(const CallBase &insn)
{
  assert(insn.getNumArgOperands() == 3);
  const Value *x = insn.getArgOperand(0);
  const Value *y = insn.getArgOperand(1);
  const Value *size = insn.getArgOperand(2);

  check_mem_access(&insn, x, size, false);
  check_mem_access(&insn, y, size, false);

  m_out << "  ";
  write_operand(insn, insn, HLS_TYPE_INTEGER);
  m_out << " = nanotube_memcmp(";
  write_operand(insn, *x, HLS_TYPE_POINTER);
  m_out << ", ";
  write_operand(insn, *y, HLS_TYPE_POINTER);
  m_out << ", ";
  write_operand(insn, *size, HLS_TYPE_INTEGER);
  m_out << ");\n";
}

void stage_writer::write_select(const SelectInst &insn)
{
  m_out << "  ";
  write_operand(insn, insn, HLS_TYPE_ANY);
  m_out << " = (";
  write_operand(insn, *(insn.getCondition()));
  m_out << " ? ";
  write_operand(insn, *(insn.getTrueValue()), HLS_TYPE_ANY);
  m_out << " : ";
  write_operand(insn, *(insn.getFalseValue()), HLS_TYPE_ANY);
  m_out << ");\n";
}

void stage_writer::write_store(const StoreInst &insn)
{
  // Only integer stores are supported; they are written byte by byte.
  const Value *data = insn.getValueOperand();
  Type *ty = data->getType();
  auto int_ty = dyn_cast<IntegerType>(ty);
  if (int_ty == nullptr) {
    report_fatal_errorv("Unsupported type in store instruction: {0}",
                        insn);
  }

  unsigned num_bytes = m_data_layout.getTypeStoreSize(ty);
  bool is_big_endian = m_data_layout.isBigEndian();

  const Value *ptr = insn.getPointerOperand();
  check_mem_access(&insn, ptr, num_bytes, true);

  // Emit one "p[i] = (<data> >> shift);" per byte of the store size.
  for (unsigned i=0; i<num_bytes; i++) {
    // Memory byte i receives value bits [8*k, 8*k+8), where k = i on a
    // little-endian target and k = num_bytes-1-i on a big-endian one.
    // (The previous "num_bytes-8-shift" mixed bytes with bits and
    // underflowed.)
    unsigned byte_index = is_big_endian ? (num_bytes - 1 - i) : i;
    unsigned shift = byte_index * 8;

    m_out << "  ";
    write_operand(insn, *ptr, HLS_TYPE_POINTER);
    m_out << "[" << i << "] = (";
    write_operand(insn, *data);
    m_out << " >> " << shift << ");\n";
  }
}

void stage_writer::write_switch(const SwitchInst &insn)
{
  // Make sure the expression type is integer.
  Type *ty = insn.getCondition()->getType();
  auto int_ty = dyn_cast<IntegerType>(ty);
  if (int_ty == nullptr) {
    report_fatal_errorv("Non-integer type in switch statement: {0}",
                        insn);
  }

  // Get the bit width.  At most 64 bits are supported, so case values
  // fit in a uint64_t below.
  unsigned num_bits = int_ty->getBitWidth();
  if (num_bits > 64) {
    report_fatal_errorv("Switch statement type is too wide: {0}",
                        insn);
  }

  // Round up to a supported size.
  num_bits = ( num_bits <= 8 ? 8 :
               ( num_bits <= 16 ? 16 :
                 ( num_bits <= 32 ? 32 :
                   64 ) ) );

  // Write the switch statement.  The condition is zero-extended to
  // uint<num_bits>_t, so each case value must be zero-extended as well.
  m_out << "  switch(uint" << num_bits << "_t(";
  write_operand(insn, *insn.getCondition());
  m_out << ")) {\n";
  for (auto &case_info: insn.cases()) {
    // Print the zero-extended value as an unsigned literal.
    // printAsOperand() prints a signed decimal.  That is wrong whenever
    // the condition is narrower than num_bits: e.g. an i3 case value of
    // 7 prints as "-1", and uint8_t(-1) is 255, not 7.  It would also
    // emit an out-of-range literal for INT64_MIN.
    uint64_t case_value = case_info.getCaseValue()->getZExtValue();
    m_out << "  case uint" << num_bits << "_t(" << case_value << "ULL):\n";
    write_cfg_edge("    ", *insn.getParent(),
                   *case_info.getCaseSuccessor());
  }
  m_out << "  default:\n";
  write_cfg_edge("    ", *insn.getParent(), *insn.getDefaultDest());
  m_out << "  }\n";
}

void stage_writer::write_usub_sat(const CallBase &insn)
{
  // Validate and extract the types involved.
  auto *ret_ty = cast<IntegerType>(insn.getType());

  assert(insn.arg_size() == 2);
  auto *op0 = insn.getArgOperand(0);
  auto *op0_ty = cast<IntegerType>(op0->getType());
  assert(op0_ty == ret_ty);
  auto *op1 = insn.getArgOperand(1);
  auto *op1_ty = cast<IntegerType>(op1->getType());
  assert(op1_ty == ret_ty);

  // Write: ret = op0 >= op1 ? ap_uint<n>(op0 - op1) : ap_uint<n>(0);
  m_out << "  ";
  write_operand(insn, insn);
  m_out << " = ";
  write_operand(insn, *op0);
  m_out << " >= ";
  write_operand(insn, *op1);
  m_out << formatv(" ? ap_uint<{0}>(", ret_ty->getBitWidth());
  write_operand(insn, *op0);
  m_out << " - ";
  write_operand(insn, *op1);
  m_out << formatv(") : ap_uint<{0}>(0);\n", ret_ty->getBitWidth());
}

void stage_writer::write_usub_with_overflow(const CallBase &insn)
{
  // Validate and extract the types involved.
  auto *ret_ty = cast<StructType>(insn.getType());
  assert(ret_ty->getNumElements() == 2);
  auto *ret0_ty = cast<IntegerType>(ret_ty->getElementType(0));
  auto *ret1_ty = cast<IntegerType>(ret_ty->getElementType(1));
  assert(ret1_ty->getBitWidth() == 1);

  assert(insn.arg_size() == 2);
  auto *op0 = insn.getArgOperand(0);
  auto *op0_ty = cast<IntegerType>(op0->getType());
  assert(op0_ty == ret0_ty);
  auto *op1 = insn.getArgOperand(1);
  auto *op1_ty = cast<IntegerType>(op1->getType());
  assert(op1_ty == ret0_ty);

  // Write: ret.m0 = op0 - op1;
  m_out << "  ";
  write_operand(insn, insn, HLS_TYPE_STRUCT);
  m_out << ".m0 = ";
  write_operand(insn, *op0);
  m_out << " - ";
  write_operand(insn, *op1);
  m_out << ";\n";

  // Write: ret.m1 = op0 < op1;
  m_out << "  ";
  write_operand(insn, insn, HLS_TYPE_STRUCT);
  m_out << ".m1 = ";
  write_operand(insn, *op0);
  m_out << " < ";
  write_operand(insn, *op1);
  m_out << ";\n";
}

void stage_writer::write_cfg_edge(StringRef indent, const BasicBlock &from_bb,
                                  const BasicBlock &to_bb)
{
  // Taking the edge from_bb -> to_bb gives every phi node of to_bb its
  // incoming value for from_bb.  Those assignments happen in parallel:
  // every incoming value must be read before any phi is overwritten.
  // If an incoming value is itself another phi of to_bb (e.g. two
  // loop-carried values being swapped), plain sequential assignments
  // would let an earlier assignment clobber a value a later one still
  // reads.  Detect that case and go through temporaries.
  bool needs_temps = false;
  for (const PHINode &phi: to_bb.phis()) {
    const auto *src_phi =
      dyn_cast<PHINode>(phi.getIncomingValueForBlock(&from_bb));
    if (src_phi != nullptr && src_phi != &phi &&
        src_phi->getParent() == &to_bb) {
      needs_temps = true;
      break;
    }
  }

  if (!needs_temps) {
    // No hazard: convert each phi node directly into an assignment on
    // the basic block edge.
    for (const PHINode &phi: to_bb.phis()) {
      const Value *val = phi.getIncomingValueForBlock(&from_bb);
      m_out << indent;
      write_operand(phi, phi, HLS_TYPE_ANY);
      m_out << " = ";
      write_operand(phi, *val, HLS_TYPE_ANY);
      m_out << ";\n";
    }
  } else {
    // Read all incoming values into temporaries first, then assign the
    // phis.  The braces scope the temporaries so the emitted code stays
    // valid where this is emitted after a switch case label.
    // NOTE(review): relies on the HLS compiler accepting C++11 'auto';
    // confirm.
    std::string inner = indent.str() + "  ";
    m_out << indent << "{\n";
    unsigned tmp_id = 0;
    for (const PHINode &phi: to_bb.phis()) {
      const Value *val = phi.getIncomingValueForBlock(&from_bb);
      m_out << formatv("{0}auto phi_tmp{1} = ", inner, tmp_id);
      write_operand(phi, *val, HLS_TYPE_ANY);
      m_out << ";\n";
      ++tmp_id;
    }
    tmp_id = 0;
    for (const PHINode &phi: to_bb.phis()) {
      m_out << inner;
      write_operand(phi, phi, HLS_TYPE_ANY);
      m_out << formatv(" = phi_tmp{0};\n", tmp_id);
      ++tmp_id;
    }
    m_out << indent << "}\n";
  }

  // Find or create an ID for the label.  insert() keeps an existing
  // entry, and otherwise assigns the next sequential ID.
  auto ins = m_label_ids.insert(
    std::make_pair(&to_bb, m_label_ids.size()));

  // Output the label.
  m_out << formatv("{0}goto L{1};\n", indent, ins.first->second);
}

void stage_writer::write_operand(const Instruction &insn, const Value &val_in,
                                 enum hls_type expected_type)
{
  const Value *val = &val_in;
  Type *type = val->getType();
  bool is_pointer = type->isPointerTy();
  if (is_pointer) {
    if (expected_type != HLS_TYPE_POINTER &&
        expected_type != HLS_TYPE_ANY)
      report_fatal_errorv("Invalid pointer arithmetic: {0}", insn);
  } else if (type->isIntegerTy()) {
    if (expected_type != HLS_TYPE_INTEGER &&
        expected_type != HLS_TYPE_ANY)
      report_fatal_errorv("Invalid integer arithmetic: {0}", insn);
  } else if (type->isStructTy()) {
    if (expected_type != HLS_TYPE_STRUCT &&
        expected_type != HLS_TYPE_ANY)
      report_fatal_errorv("Invalid use of struct: {0}", insn);

  } else {
    report_fatal_errorv("Invalid type for arithmetic: {0}", insn);
  }

  // Strip constant GEPs and bitcasts.
  unsigned int pointer_bits = m_data_layout.getTypeSizeInBits(type);
  auto buffer_offset = APInt(pointer_bits, 0);
  if (is_pointer) {
    // Get the buffer and offset.
    val = val->stripAndAccumulateInBoundsConstantOffsets
            (m_data_layout, buffer_offset);

    if (buffer_offset != 0)
      m_out << '(';
  }

  write_operand_base(insn, *val);

  if (buffer_offset != 0) {
    m_out << formatv("+{0})", buffer_offset);
  }
}

void stage_writer::write_operand_base(const Instruction &insn,
                                      const Value &val)
{
  const auto *const_int = dyn_cast<ConstantInt>(&val);
  if (const_int != nullptr) {
    IntegerType *t = const_int->getType();
    m_out << formatv("ap_uint<{0}>(", t->getBitWidth());
    const_int->printAsOperand(m_out, false);
    m_out << ")";
    return;
  }

  if (isa<Instruction>(val)) {
    auto it = m_local_var_ids.find(&val);
    if ( it == m_local_var_ids.end() )
      report_fatal_errorv("Failed to find local variable for {0}", val);
    m_out << 'v' << it->second;
    return;
  }

  if (isa<GlobalVariable>(val)) {
    auto it = m_static_var_ids.find(&val);
    if (it == m_static_var_ids.end())
      report_fatal_errorv("Failed to find static variable for {0}", val);
    m_out << "s" << it->second;
    return;
  }

  report_fatal_errorv("Unsupported variable {0} in {1}",
                      val, insn);
}

void stage_writer::check_mem_access(const Instruction *insn, const Value *base,
                                    uint64_t size, bool is_write)
{
  // Represent the byte count as an APInt as wide as the pointer type
  // and defer to the APInt overload.
  unsigned width = m_data_layout.getTypeSizeInBits(base->getType());
  check_mem_access(insn, base, APInt(width, size), is_write);
}

void stage_writer::check_mem_access(const Instruction *insn, const Value *base,
                                    const Value *size, bool is_write)
{
  // Bounds can only be checked when the size is a constant.  Otherwise
  // just report it on stderr and skip the check.
  if (const auto *len = dyn_cast<ConstantInt>(size)) {
    check_mem_access(insn, base, len->getValue(), is_write);
    return;
  }
  errs() << "Unknown size in " << *insn << "\n";
}

void stage_writer::check_mem_access(const Instruction *insn,
                                    const Value *base, APInt access_size,
                                    bool is_write)
{
  auto offset = APInt(access_size.getBitWidth(), 0);
  base = base->stripAndAccumulateInBoundsConstantOffsets
    (m_data_layout, offset);

  auto ptr_type = dyn_cast<PointerType>(base->getType());
  if (ptr_type == nullptr) {
    report_fatal_errorv("Access through a non-pointer {0} in {1}",
                        *base, *insn);
  }

  auto glb_var = dyn_cast<GlobalVariable>(base);
  if (glb_var != nullptr) {
    // Assign the variable to this thread so it cannot be used by
    // multiple threads.
    m_top.set_thread_of_var(*base, m_thread_id, is_write);

    auto var_type = glb_var->getValueType();
    auto var_size = m_data_layout.getTypeStoreSize(var_type);
    auto max_size = var_size - offset;
    if (offset.ugt(var_size) || access_size.ugt(max_size)) {
      errs() << formatv("WARNING: Out of bounds access in {0}\n",
                        *insn);
    }
    return;
  }

  auto alloca = dyn_cast<AllocaInst>(base);
  if (alloca != nullptr) {
    auto var_size = get_alloca_size(m_data_layout, *alloca);
    auto max_size = var_size - offset;
    if (offset.uge(var_size) || access_size.ugt(max_size)) {
      errs() << formatv("WARNING: Out of bounds access in {0}\n",
                        *insn);
    }
    return;
  }
}

///////////////////////////////////////////////////////////////////////////

Pass *nanotube::create_hls_printer(const std::string &output_directory,
                                   bool overwrite)
{
  // Factory for the HLS printer pass.  The returned object is heap
  // allocated; presumably the caller (pass manager) takes ownership.
  // Confirm against the caller.
  Pass *printer = new hls_printer(output_directory, overwrite);
  return printer;
}

///////////////////////////////////////////////////////////////////////////
